*! version 1.1, 29jan2019
*! Peter Hull (hull@uchicago.edu)
program define ivvam, eclass sortpreserve
   version 11.0
   /*
      ivvam -- IV value-added model (random-effect) estimator.

      Input: an IV-style specification parsed by _iv_parse into a dependent
      variable (lhs), endogenous variables (endog), excluded instruments
      (inst) and exogenous controls (exog), plus the options
        enroll(varlist)    regressed on the partialled-out instruments to
                           form pihat, which feeds P (and R under dofc)
        cluster(varname)   only used to report a cluster count and set
                           e(vcetype); NOTE(review): no computation visible
                           here uses it in the SEs, although the output says
                           they are adjusted -- confirm intended behavior
        dofc               with endogenous variables, compute P from the
                           projection-based matrix R instead of
                           trace(Z'Z * pihat*pihat')

      Posts e(b) = [mu_hat, sigma] and e(V) (the variance of sigma is fixed
      at 0), plus e(sigma_hat), e(sigma_VE), e(depvar), e(endvar),
      e(enroll), e(instrm), e(exovar), e(N) and, with cluster(), e(vcetype)
      and e(N_clust).

      Requires the user-written command ivreg2.  Leaves several Stata
      matrices (e.g. mu_hat, V_hat, ZZ, pihat, P, U, mu_VCE, Lambda_11) and
      Mata globals in memory.  sortpreserve restores the caller's sort order.
   */
   syntax anything, enroll(varlist) [cluster(varname)] [dofc]

   /* sample size: every observation in memory is counted.
      NOTE(review): no if/in and no missing-value screening; st_data() below
      also reads all rows, so missing values would propagate into the Mata
      results -- confirm inputs are complete */
   qui count
   local N=r(N)

   /* parse iv syntax into lhs, endog, exog and inst */
   _iv_parse `0'
   local lhs `s(lhs)'
   local endog `s(endog)'
   local exog `s(exog)'
   local inst `s(inst)'
   local 0 `s(zero)'

   /* partial out the controls (and the constant) from every variable;
      C is the model df of the control regression */
   qui _regress `lhs' `exog'
   local C = e(df_m)
   tempvar fw_lhs
   predict `fw_lhs', r
   local fw_endog
   local M=0
   local eq ""
   if "`endog'" != "" {
      foreach var of varlist `endog' {
         qui _regress `var' `exog'
         tempvar fw_`var'
         predict `fw_`var'', r
         local fw_endog `fw_endog' `fw_`var''
         local M=`M'+1
         local eq "`eq' Predictors"
      }
   }
   local fw_inst
   local L=0
   foreach var of varlist `inst' {
      qui _regress `var' `exog'
      tempvar fw_`var'
      predict `fw_`var'', r
      local fw_inst `fw_inst' `fw_`var''
      local L=`L'+1
   }

   /* k-class parameter passed to ivreg2 kclass() below */
   local k_mbtsls = (1 - `C'/`N') / (1 - `L'/`N' - `C'/`N')

   /* estimate VAM prior mean: a LIML fit (its coefficients are reused for
      Omega_re) and the k-class fit whose coefficients are reported;
      savefirst stores each first stage as _ivreg2_ followed by the name of
      the partialled endogenous variable */
   qui ivreg2 `fw_lhs' (`fw_endog'=`fw_inst'), nocons liml
   matrix mu_hat_liml=e(b)
   qui ivreg2 `fw_lhs' (`fw_endog'=`fw_inst'), nocons savefirst kclass(`k_mbtsls')
   matrix mu_hat=e(b)
   matrix V_hat=e(V)

   /* k-class residuals projected on the instruments */
   tempvar u q
   predict `u', r
   qui reg `u' `fw_inst', nocons
   matrix U_mss=e(mss)

   /* residual variance of that projection, times (L - M) */
   predict `q', r
   qui su `q'
   matrix adj_num = r(sd)^2 * (`L' - `M')

   /* first-stage coefficients: column j of fs (L x M) is filled from the
      stored first stage of the j-th endogenous variable
      (assumes that e(b) is 1 x L because of nocons -- TODO confirm) */
   if "`endog'" != "" {
      matrix fs=J(`L',`M',0)
      local j=1
      foreach var of varlist `fw_endog' {
         qui est restore _ivreg2_`var'
         matrix temp=e(b)
         matrix fs[1,`j']=temp'
         /* best-effort cleanup so stored first stages do not pile up */
         capture qui estimates drop _ivreg2_`var'
         local j=`j'+1
      }
   }

   /* estimate VAM residual variance */
   mata Z=st_data(.,"`fw_inst'")
   mata D=st_data(.,"`enroll'")
   mata ZZ=Z'*Z
   mata invZZ=invsym(ZZ)
   mata pihat=invZZ*(Z'*D)
   mata fs=st_matrix("fs")
   mata st_matrix("ZZ",ZZ)
   mata st_matrix("invZZ",invZZ)
   mata st_matrix("pihat",pihat)
   if "`dofc'" != "" & "`endog'" != "" {
      /* NOTE(review): U and adj_denom are not referenced later in this
         program */
      matrix U = U_mss - adj_num
      /* Z'*(Z*invZZ*Z' - Z*fs*W*fs'*Z')*Z expanded to
         ZZ*invZZ*ZZ - ZZ*fs*W*fs'*ZZ, so no N x N matrix is formed */
      mata R=pihat'*(ZZ*invZZ*ZZ - ZZ*fs*invsym(fs'*ZZ*fs)*fs'*ZZ)*pihat
      matrix P=trace((ZZ)*(pihat*pihat'))
      mata st_matrix("R",R)
      matrix adj_denom = trace(R) - trace(P)
      matrix P=trace(R)
   }
   else {
      matrix U = U_mss
      matrix P=trace((ZZ)*(pihat*pihat'))
   }

   /* compute SEs: projected (YPY) and residual (YMY) cross-products of
      [lhs, endog] with respect to the instruments */
   mata Y=st_data(.,"`fw_lhs'")
   mata T=st_data(.,"`fw_endog'")
   mata YT = Y, T
   mata ZYT = Z'*YT
   mata YPY = ZYT'*invZZ*ZYT
   mata YMY = YT'*YT - YPY
   mata Sp = YMY / (`N'-`L'-`C')
   mata S = YPY / `N'
   mata mu_hat = st_matrix("mu_hat")
   /* Gamma: identity with -mu_hat in the first column below row 1 */
   mata Gamma = I(`=`M'+1')
   forval i=2/`=`M'+1' {
      local m = `i' - 1
      mata Gamma[`i',1] = -1*mu_hat[1,`m']
   }

   /* estimate Lambda_11, truncated below at 0 */
   mata b  = 1 \ -1*mu_hat'
   mata temp = b' * (S - `L'/`N' * Sp) * b
   mata Lambda_11 = max(0 \ temp[1,1])

   /* estimate Lambda_22: branch on the smallest eigenvalue of inv(Sp)*S */
   mata mmin = min(Re(eigenvalues(invsym(Sp)*S)))
   mata st_matrix("mmin",mmin)
   local mmin = mmin[1,1]
   local mat_size = `M'+1

   mata lambda_re = max(Re(eigenvalues(invsym(Sp)*S))) - `L'/`N'
   mata mu_hat_liml = st_matrix("mu_hat_liml")
   mata a = mu_hat_liml' \ 1
   mata Omega_re = (`N'-`L'-`C')/(`N'-`C')*Sp + `N'/(`N'-`C')*(S - lambda_re/(a'*invsym(Sp)*a)*(a*a'))

   if `mmin' > `L'/`N' {
      mata Omega_ure = Sp
      mata Lambda_22 = S[2..`mat_size',2..`mat_size'] - `L'/`N' * Omega_ure[2..`mat_size',2..`mat_size']
   }
   else {
      mata Omega_ure = Omega_re
      mata Lambda_22 = lambda_re/(a'*invsym(Omega_ure)*a) * I(`M')
   }

   /* estimate Sigma */
   mata Sigma = Gamma' * Omega_re * Gamma

   /* estimate VCE */
   local h = (1 - `C'/`N') * `L'/`N' / (1 - `C'/`N' - `L'/`N')
   mata Vvalid = Sigma[1,1]*invsym(Lambda_22) + `h'*invsym(Lambda_22)*(Sigma[1,1]*Sigma[2..`=`M'+1',2..`=`M'+1']+Sigma[2..`=`M'+1',1]*Sigma[2..`=`M'+1',1]')*invsym(Lambda_22)
   mata Vinvalid = Vvalid + invsym(Lambda_22)*(Lambda_11*Sigma[2..`=`M'+1',2..`=`M'+1'] + Lambda_11*Lambda_22*`N'/`L')*invsym(Lambda_22)
   mata mu_VCE = Vinvalid:/`N'
   mata st_matrix("mu_VCE",mu_VCE)
   mat colnames mu_VCE = `endog'
   mat rownames mu_VCE = `endog'
   local sigma_VE = 0

   /* estimate of sigma_nu^2 = Lambda_11 / (trace(P)/N); sigma_hat is its
      square root when positive and 0 otherwise */
   mata st_matrix("Lambda_11",Lambda_11)
   mat sigma2_hat = Lambda_11 / (trace(P) / `N')
   local sigma2_hat = sigma2_hat[1,1]
   if `sigma2_hat' > 0 {
      local sigma_hat=sqrt(`sigma2_hat')
   }
   else {
      local sigma_hat=0
   }

   /* if sigma_nu == 0, recalculate VCE from ivreg2's k-class fit on the
      original variables with the controls partialled out */
   if `sigma_hat'==0 {
      qui ivreg2 `lhs' (`endog' = `inst') `exog', partial(`exog') kclass(`k_mbtsls')
      mat V_hat = e(V)
      matrix mu_VCE=V_hat
      mat colnames mu_VCE = `endog'
      mat rownames mu_VCE = `endog'
   }

   /* return output */
   if "`endog'" != "" {
      matrix b=[mu_hat,`sigma_hat']
      matrix colnames b=`endog' sigma
      mat coleq b=`eq' Residual
      matrix V=[mu_VCE,J(`M',1,0)\J(1,`M',0),`sigma_VE']
      matrix colnames V=`endog' sigma
      matrix rownames V=`endog' sigma
      mat coleq V=`eq' Residual
      mat roweq V=`eq' Residual
   }
   else {
      matrix b=[`sigma_hat']
      matrix colnames b=sigma
      mat coleq b=Residual
      matrix V=[`sigma_VE']
      matrix colnames V=sigma
      matrix rownames V=sigma
      mat coleq V=Residual
      mat roweq V=Residual
   }
   ereturn post b V
   ereturn local sigma_hat = `sigma_hat'
   ereturn local sigma_VE = `sigma_VE'
   ereturn local depvar `lhs'
   ereturn local endvar `endog'
   ereturn local enroll `enroll'
   ereturn local instrm `inst'
   ereturn local exovar `exog'
   ereturn scalar N=`N'

   di as txt _newline "IV-VAM (random effect) estimation"
   _coef_table_header
   _coef_table
   if "`cluster'"!="" {
      ereturn local vcetype "Clustered"
      /* count distinct non-missing cluster values; egen tag() has no
         row limit, unlike a one-way tabulate */
      tempvar ctag
      qui egen byte `ctag' = tag(`cluster')
      qui count if `ctag'
      local N_clust=r(N)
      di as txt "Standard errors adjusted for " `N_clust' as txt " clusters in `cluster'"
      ereturn local N_clust=`N_clust'
   }

end
